#  Copyright (c) 2023, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause


from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass


@register_pass(namespace="common")
class loop_invariant_elimination(AbstractGraphPass):
    """
    When a block does not modify a block input var, eliminate that block
    input var and use the corresponding var in the outer scope. Example:

    .. code-block::

       # Before loop_invariant_elimination pass.
       # Notice that ``%b.x`` is constant through while loop iterations.
       main(%a: (1, 2, fp32),
            %b: (1, 2, fp32)) {
         block0() {
           %loop:0: (1, 2, fp32), %loop:1: (1, 2, fp32) = \
           while_loop(loop_vars=(%a, %b))
             loop_cond(%a.x, %b.x) {
               %cond_var: (bool) = some_op(x=%a.x, y=%b.x)
             } -> (%cond_var)
             loop_body(%a.x, %b.x) {
               %add_0: (1, 2, fp32) = add(x=%a.x, y=%b.x)
             } -> (%add_0, %b.x)
         } -> (%loop:0, %loop:1)
       }

       # After loop_invariant_elimination pass.
       main(%a: (1, 2, fp32),
            %b: (1, 2, fp32)) {
         block0() {
           %loop:1: (1, 2, fp32) = identity(x=%b)
           %loop:0: (1, 2, fp32) = \
           while_loop(loop_vars=(%a))
             loop_cond(%a.x) {
               %cond_var: (bool) = some_op(x=%a.x, y=%b)
             } -> (%cond_var)
             loop_body(%a.x) {
               %add_0: (1, 2, fp32) = add(x=%a.x, y=%b)
             } -> (%add_0)
         } -> (%loop:0, %loop:1)
       }

    where we eliminate loop invariant ``%b.x`` from ``while_loop``, which returns 1
    instead of 2 outputs. We also preserve the return var names with identity.
    """

    def apply(self, prog):
        """Run loop invariant elimination on every function of ``prog``."""
        for f in prog.functions.values():
            self._loop_invariant_elimination_block(f)

    @staticmethod
    def _detect_loop_invariants(while_op):
        """
        Return the indices of the loop invariant loop vars of ``while_op``.

        Index ``i`` (valid for ``while_op.loop_vars`` and for the inputs of
        each sub-block) is reported when the i-th output of the body block is
        either the i-th body-block input itself, or a var that
        ``is_var_visible_in_block`` reports as visible in the block enclosing
        ``while_op`` up to ``while_op``.

        Returns:
            list[int]: invariant indices, in increasing order.
        """
        body_block = while_op.blocks[1]  # blocks[1] is the body block
        # The enclosing block does not change while we scan the loop vars.
        enclosing_block = while_op.enclosing_block

        loop_invariant_ids = []  # list of index in op.loop_vars, block.inputs
        for i, vx_in in enumerate(body_block.inputs):
            # Body-block outputs are indexed the same way as its inputs.
            vx_out = body_block.outputs[i]
            return_input_as_output = vx_in == vx_out

            # this block output is a var from outside of the block
            output_from_outside_of_block = enclosing_block.is_var_visible_in_block(
                vx_out, upto_op=while_op
            )
            if return_input_as_output or output_from_outside_of_block:
                loop_invariant_ids.append(i)

        # TODO: All outputs that depend on only invariants are invariant. We
        # need to move computation out of while loop.
        return loop_invariant_ids

    @staticmethod
    def _eliminate_loop_invariants(while_op, loop_invariant_ids):
        """
        Remove the loop vars at ``loop_invariant_ids`` from ``while_op`` in place.

        Uses of the eliminated sub-block inputs and of the eliminated
        ``while_op`` outputs are redirected to the corresponding outer-scope
        loop var before the inputs, loop vars and outputs are dropped. The
        enclosing block is validated at the end.

        Args:
            while_op: the ``while_loop`` op to rewrite.
            loop_invariant_ids: indices as returned by
                ``_detect_loop_invariants`` for this op.
        """
        # Set for O(1) membership tests in the filtering comprehensions below.
        invariant_id_set = set(loop_invariant_ids)

        # replace uses of loop_invariants with its source from outside of the
        # while_loop op.
        for i in loop_invariant_ids:
            for sub_block in while_op.blocks:
                sub_block.replace_uses_of_var_after_op(
                    anchor_op=None,
                    old_var=sub_block.inputs[i],
                    new_var=while_op.loop_vars[i],
                )

        # replace block inputs
        for sub_block in while_op.blocks:
            sub_block.remove_inputs([sub_block.inputs[i] for i in loop_invariant_ids])

        # replace usage of while_loop outputs that we'll eliminate.
        for i in loop_invariant_ids:
            while_op.enclosing_block.replace_uses_of_var_after_op(
                anchor_op=while_op,
                old_var=while_op.outputs[i],
                new_var=while_op.loop_vars[i],
            )

        # Remove after replacing to ensure program is valid
        for i in loop_invariant_ids:
            while_op.loop_vars[i].remove_child_op(while_op)

        # remove invariants from while_loop loop_vars
        while_op.loop_vars = tuple(
            v for i, v in enumerate(while_op.loop_vars) if i not in invariant_id_set
        )
        while_op._input_vars["loop_vars"] = while_op.loop_vars

        # remove invariants from while_loop body_block outputs
        body_block = while_op.blocks[1]
        body_block.set_outputs(
            [v for i, v in enumerate(body_block.outputs) if i not in invariant_id_set]
        )

        # op._output_vars doesn't include cond var
        while_op._output_vars = [
            v for i, v in enumerate(while_op._output_vars) if i not in invariant_id_set
        ]

        # check healthy state
        while_op.enclosing_block.validate()

    @block_context_manager
    def _loop_invariant_elimination_block(self, block):
        """
        Eliminate loop invariants from every ``while_loop`` in ``block``.

        Nested blocks are processed first (recursively), then the
        ``while_loop`` ops that are direct children of ``block``.
        """
        # Phase 1: Find vars needed to be renamed.
        #
        # while_loop outputs need to be renamed if the output will be eliminated
        # (due to loop invariant) and is returned as block output (which would
        # change the return var name and the program interface).
        #
        # list[(v_src, v_tgt, before_op)]: will rename v_src to v_tgt before
        # before_op (a while_loop)
        output_rename = []
        for op in list(block.operations):
            for b in op.blocks:
                self._loop_invariant_elimination_block(b)

            if op.op_type != "while_loop":
                continue

            loop_invariant_ids = self._detect_loop_invariants(op)
            output_rename.extend(
                (op.loop_vars[i], op.outputs[i], op) for i in loop_invariant_ids
            )
            if loop_invariant_ids:
                # Avoid the following case:
                # %a, %b = while_loop(..., name="b")
                # becomes
                # %b = identity(..., name="b")
                # %a = while_loop(..., name="b")
                # (two ops with the same name -> name collision)
                op.name = op.name + "_renamed"

        # Phase 2: insert rename ops. This changes block.operations
        for v_src, v_tgt, op in output_rename:
            if v_tgt in block.outputs:
                # rename the loop output to existing block output names
                res = mb.identity(x=v_src, before_op=op, name=v_tgt.name)
                op.enclosing_block.replace_uses_of_var_after_op(
                    anchor_op=op, old_var=v_tgt, new_var=res
                )

        # Phase 3: Perform loop invariant elimination without fear!
        # Invariants are re-detected here because Phase 2 changed the block.
        for op in list(block.operations):
            if op.op_type != "while_loop":
                continue
            self._eliminate_loop_invariants(op, self._detect_loop_invariants(op))
